{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这一节主要聊一下马尔可夫链在NLP（自然语言处理）上的一个应用：语言模型，以及如何利用马尔可夫链去生成新的数据\n",
    "### 一.语言模型\n",
    "语言模型的目的其实很简单，就是去判断一句话是不是“人话”，那如何判断呢？可以通过概率来判断，我们可以将一段话看作关于词的序列问题比如“我喜欢吃苹果”可以看作一个序列['我','喜欢','吃','苹果']，于是判断一句话的概率变成了判断一个序列的概率：   \n",
    "\n",
    "$$\n",
    "p(我喜欢吃苹果)=p(我,喜欢,吃,苹果)\n",
    "$$  \n",
    "\n",
    "那如何计算这些概率呢？我们可以去网上搜集大量的文本数据，比如新闻、论坛...然后进行统计即可，我们再接着分析一下，汉语常用词语量在2000左右，如果判断的句子长度为2，那么我们需要$2000^2=4000000$个单词量的文本才能有效的统计，那如果句子长度为3呢，那么就需要$2000^3=8\\cdot10^9$，句子长度为4，就需要$2000^4=16\\cdot 10^{12}$...所以，如果随便给一个句子，那么这个句子的概率极有可能为很低、甚至为0！而我们目的是看一句话是否为“人话”，其实是一个相对概率，比如只要我们的语言模型能做到如下的判断，我们就可以认为我们的模型是可以接受的：   \n",
    "\n",
    "$$\n",
    "p(我,喜欢,吃,苹果)>p(苹果,喜欢,吃,我)\n",
    "$$  \n",
    "\n",
    "那这样我们就可以将对一个句子的全局判断缩小到对其进行一些局部判断，比如“苹果”在“吃”之后的概率应该要大于它在其之前的概率，即：   \n",
    "\n",
    "$$\n",
    "p(苹果\\mid 吃)>p(吃\\mid 苹果)\n",
    "$$  \n",
    "\n",
    "而仅考虑局部的关系，我们自然可以引入马尔可夫链咯，即判断一个句子的概率，我们可以简单看作求马尔可夫链的联合概率：   \n",
    "\n",
    "$$\n",
    "p(x_1,x_2,...,x_n)=p(x_1)\\cdot p(x_2\\mid x_1)\\cdot p(x_3\\mid x_2)\\cdots p(x_n\\mid x_{n-1})\n",
    "$$  \n",
    "\n",
    "这里$x_i,i=1,2,...,n$分别对应一个词，这里每个词只与它的前一个词有关系，这时的模型称为bi-gram模型，我们可以扩展到与前两个词相关，这时便是对应的二阶马尔可夫链，相应的语言模型称为tri-gram模型，接下来实践看看"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二.利用今日头条文本数据训练语言模型\n",
    "我在[这里>>>](https://github.com/skdjfla/toutiao-text-classfication-dataset)下载了一些今日头条的文本分类数据并作了预处理后截取了前10000句，每句话间用\"\\n\"分割，每个词之间用空格分割"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "text=[]\n",
    "for line in open(\"./data/toutiao_mini.txt\",encoding=\"utf8\"):\n",
    "    text.append(line.strip().split(\" \"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['京城', '最', '值得', '你', '来场', '文化', '之旅', '的', '博物馆'],\n",
       " ['发酵', '床', '的', '垫料', '种类', '有', '哪些', '？', '哪', '种', '更好', '？']]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#看看前两行\n",
    "text[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#接下来构建词典，将汉字映射为整数方便模型训练\n",
    "word2idx={}\n",
    "idx2word={}\n",
    "idx=0\n",
    "for line in text:\n",
    "    for word in line:\n",
    "        if word not in word2idx:\n",
    "            word2idx[word]=idx\n",
    "            idx2word[idx]=word\n",
    "            idx+=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#接下来将汉字转换为整数\n",
    "train_data=[]\n",
    "for line in text:\n",
    "    train_data.append([word2idx[word] for word in line])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 7, 11, 12, 13, 14, 15, 16, 17, 18, 15]]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#查看前两行\n",
    "train_data[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "#好了，可训练语言模型了\n",
    "import os\n",
    "os.chdir('../')\n",
    "from ml_models.pgm import SimpleMarkovModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\OneDriver\\OneDrive - email.swu.edu.cn\\self\\learning\\self_project\\ML_Notes\\ml_models\\pgm\\simple_markov_model.py:33: RuntimeWarning: invalid value encountered in true_divide\n",
      "  self.P = self.P / np.sum(self.P, axis=0)\n"
     ]
    }
   ],
   "source": [
    "smm=SimpleMarkovModel(status_num=len(word2idx))\n",
    "smm.fit(train_data)\n",
    "del train_data\n",
    "del text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模型训练完了，让我们看看概率为0的占比，可以发现绝大部分都很稀疏"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9620576756873548"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.sum(smm.P==0)/(smm.P.shape[0]*smm.P.shape[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来为概率为0的语句添加一个默认概率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "smm.P=np.where(smm.P==0,1.0/smm.P.shape[0],smm.P)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "随机说句话看看概率，嗯，情理之中的结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-29.792768141732772\n",
      "-36.43292810682986\n"
     ]
    }
   ],
   "source": [
    "print(smm.predict_log_joint_prob([word2idx[word] for word in [\"我\",\"爱\",\"马云\",\"爸爸\"]]))\n",
    "print(smm.predict_log_joint_prob([word2idx[word] for word in [\"马云\",\"爸爸\",\"爱\",\"我\"]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三.利用马尔可夫链生成数据\n",
    "既然马尔可夫链是生成模型，那么我们可以利用其来生成数据，这一节介绍两种方式：greedy search和beam search\n",
    "\n",
    "#### greedy search\n",
    "greedy search即贪婪搜索，每一次都选择使当前时刻概率最大的状态，即   \n",
    "\n",
    "$$\n",
    "S_{next}^*=arg\\max_{S_{next}}p(X_{t}=S_{next}\\mid X_{t-1}=S_{current})\n",
    "$$  \n",
    "\n",
    "其中$S_{current}$已知，所以只需要搜索状态转移概率矩阵就可以得到结果\n",
    "#### beam search\n",
    "greedy search最大的缺点就是会隐藏掉低概率状态后面的高概率状态，而beam search的处理方案就要聪明些，它会同时保留目前top K大的结果，最后从这K个结果里面再选择最优的结果   \n",
    "\n",
    "接下来，我们在代码中新增加这部分功能：   \n",
    "```python\n",
    "    \n",
    "\n",
    "    def generate_status(self, step_times=10, stop_status=None, set_start_status=None, search_type=\"greedy\", beam_num=5):\n",
    "        \"\"\"\n",
    "        生成状态序列，包括greedy search和beam search两种方式\n",
    "        :param step_times: 步长不超过 step_times\n",
    "        :param stop_status: 中止状态列表\n",
    "        :param set_start_status: 人为设置初始状态\n",
    "        :param search_type: 搜索策略，包括greedy和beam\n",
    "        :param beam_num: 只有在search_type=\"beam\"时生效，保留前top个结果\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        if stop_status is None:\n",
    "            stop_status = []\n",
    "        # 初始状态\n",
    "        start_status = np.random.choice(len(self.pi.reshape(-1)),\n",
    "                                        p=self.pi.reshape(-1)) if set_start_status is None else set_start_status\n",
    "        if search_type == \"greedy\":\n",
    "            # 贪婪搜索\n",
    "            rst = [start_status]\n",
    "            for _ in range(0, step_times):\n",
    "                next_status = self.predict_next_step_status(current_status=start_status)\n",
    "                rst.append(next_status)\n",
    "                if next_status in stop_status:\n",
    "                    break\n",
    "        else:\n",
    "            # beam search\n",
    "            rst = [start_status]\n",
    "            top_k_rst = [[start_status]]\n",
    "            top_k_prob = [0.0]\n",
    "            for _ in range(0, step_times):\n",
    "                new_top_k_rst = []\n",
    "                new_top_k_prob = []\n",
    "                for k_index, k_rst in enumerate(top_k_rst):\n",
    "                    k_rst_last_status = k_rst[-1]\n",
    "                    # 获取前k大的idx\n",
    "                    top_k_idx = self.P[:, k_rst_last_status].argsort()[::-1][0:beam_num]\n",
    "                    for top_k_status in top_k_idx:\n",
    "                        new_top_k_rst.append(k_rst + [top_k_status])\n",
    "                        new_top_k_prob.append(top_k_prob[k_index] + np.log(1e-12+self.P[top_k_status, k_rst_last_status]))\n",
    "                # 对所有的beam_num*beam_num个结果排序取前beam_num个结果\n",
    "                top_rst_idx = np.asarray(new_top_k_prob).argsort()[::-1][0:beam_num]\n",
    "                rst = new_top_k_rst[top_rst_idx[0]]\n",
    "                # 更新\n",
    "                top_k_rst = []\n",
    "                top_k_prob = []\n",
    "                for top_idx in top_rst_idx[:beam_num]:\n",
    "                    if new_top_k_rst[top_idx][-1] in stop_status:\n",
    "                        rst = new_top_k_rst[top_idx]\n",
    "                        break\n",
    "                    else:\n",
    "                        top_k_rst.append(new_top_k_rst[top_idx])\n",
    "                        top_k_prob.append(new_top_k_prob[top_idx])\n",
    "        return rst\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['国产', '航母', '战斗群', '浮出', '水面', '，', '你', '怎么', '对', '下联', '？']"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[idx2word[idx] for idx in smm.generate_status(search_type=\"beam\",stop_status=[word2idx[word] for word in [\"?\",\"!\"]])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
